{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-28T13:33:38.625117Z",
     "iopub.status.busy": "2021-07-28T13:33:38.624579Z",
     "iopub.status.idle": "2021-07-28T13:33:38.874966Z",
     "shell.execute_reply": "2021-07-28T13:33:38.874179Z",
     "shell.execute_reply.started": "2021-07-28T13:33:38.625069Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# pandas 数据集读取，dataframe形式的\n",
    "import pandas as pd\n",
    "# 文件读取\n",
    "import codecs\n",
    "train_df = pd.read_csv('train.csv', sep='\\t', names=['question1', 'question2', 'label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-28T13:33:39.322107Z",
     "iopub.status.busy": "2021-07-28T13:33:39.321553Z",
     "iopub.status.idle": "2021-07-28T13:33:39.336131Z",
     "shell.execute_reply": "2021-07-28T13:33:39.335566Z",
     "shell.execute_reply.started": "2021-07-28T13:33:39.322058Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question1</th>\n",
       "      <th>question2</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>有哪些女明星被潜规则啦</td>\n",
       "      <td>哪些女明星被潜规则了</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>怎么支付宝绑定银行卡？</td>\n",
       "      <td>银行卡怎么绑定支付宝</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>请问这部电视剧叫什么名字</td>\n",
       "      <td>请问谁知道这部电视剧叫什么名字</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>泰囧完整版下载</td>\n",
       "      <td>エウテルペ完整版下载</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>在沧州市区哪家卖的盐焗鸡好吃？</td>\n",
       "      <td>沧州饭店哪家便宜又好吃又实惠</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         question1        question2  label\n",
       "0      有哪些女明星被潜规则啦       哪些女明星被潜规则了      1\n",
       "1      怎么支付宝绑定银行卡？       银行卡怎么绑定支付宝      1\n",
       "2     请问这部电视剧叫什么名字  请问谁知道这部电视剧叫什么名字      1\n",
       "3          泰囧完整版下载       エウテルペ完整版下载      0\n",
       "4  在沧州市区哪家卖的盐焗鸡好吃？   沧州饭店哪家便宜又好吃又实惠      0"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:08.358976Z",
     "start_time": "2021-03-11T09:35:08.347065Z"
    },
    "execution": {
     "iopub.execute_input": "2021-07-28T13:33:40.340164Z",
     "iopub.status.busy": "2021-07-28T13:33:40.339779Z",
     "iopub.status.idle": "2021-07-28T13:33:41.257145Z",
     "shell.execute_reply": "2021-07-28T13:33:41.256143Z",
     "shell.execute_reply.started": "2021-07-28T13:33:40.340131Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:09.263768Z",
     "start_time": "2021-03-11T09:35:09.227357Z"
    },
    "execution": {
     "iopub.execute_input": "2021-07-28T13:33:41.258151Z",
     "iopub.status.busy": "2021-07-28T13:33:41.257945Z",
     "iopub.status.idle": "2021-07-28T13:33:41.264379Z",
     "shell.execute_reply": "2021-07-28T13:33:41.263981Z",
     "shell.execute_reply.started": "2021-07-28T13:33:41.258134Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 划分为训练集和验证集\n",
    "# stratify 按照标签进行采样，训练集和验证部分同分布\n",
    "q1_train, q1_val, q2_train, q2_val, train_label, test_label =  train_test_split(\n",
    "    train_df['question1'].iloc[:], \n",
    "    train_df['question2'].iloc[:],\n",
    "    train_df['label'].iloc[:],\n",
    "    test_size=0.1, \n",
    "    stratify=train_df['label'].iloc[:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-28T13:33:41.265321Z",
     "iopub.status.busy": "2021-07-28T13:33:41.265130Z",
     "iopub.status.idle": "2021-07-28T13:33:41.340560Z",
     "shell.execute_reply": "2021-07-28T13:33:41.339402Z",
     "shell.execute_reply.started": "2021-07-28T13:33:41.265300Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# input_ids：字的编码\n",
    "# token_type_ids：标识是第一个句子还是第二个句子\n",
    "# attention_mask：标识是不是填充"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:19.935035Z",
     "start_time": "2021-03-11T09:35:09.908638Z"
    },
    "execution": {
     "iopub.execute_input": "2021-07-28T13:33:42.043946Z",
     "iopub.status.busy": "2021-07-28T13:33:42.043420Z",
     "iopub.status.idle": "2021-07-28T13:33:47.413910Z",
     "shell.execute_reply": "2021-07-28T13:33:47.412843Z",
     "shell.execute_reply.started": "2021-07-28T13:33:42.043899Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/.local/lib/python3.6/site-packages/requests/__init__.py:91: RequestsDependencyWarning: urllib3 (1.26.6) or chardet (3.0.4) doesn't match a supported version!\n",
      "  RequestsDependencyWarning)\n"
     ]
    }
   ],
   "source": [
    "# pip install transformers\n",
    "# transformers bert相关的模型使用和加载\n",
    "from transformers import BertTokenizer\n",
    "# 分词器，词典\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')\n",
    "train_encoding = tokenizer(list(q1_train), list(q2_train), \n",
    "                           truncation=True, padding=True, max_length=100)\n",
    "val_encoding = tokenizer(list(q1_val), list(q2_val), \n",
    "                          truncation=True, padding=True, max_length=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:43.578135Z",
     "start_time": "2021-03-11T09:35:43.571452Z"
    },
    "execution": {
     "iopub.execute_input": "2021-07-28T13:33:47.415038Z",
     "iopub.status.busy": "2021-07-28T13:33:47.414865Z",
     "iopub.status.idle": "2021-07-28T13:33:47.419425Z",
     "shell.execute_reply": "2021-07-28T13:33:47.418993Z",
     "shell.execute_reply.started": "2021-07-28T13:33:47.415021Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 数据集读取\n",
    "class XFeiDataset(Dataset):\n",
    "    def __init__(self, encodings, labels):\n",
    "        self.encodings = encodings\n",
    "        self.labels = labels\n",
    "    \n",
    "    # 读取单个样本\n",
    "    def __getitem__(self, idx):\n",
    "        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
    "        item['labels'] = torch.tensor(int(self.labels[idx]))\n",
    "        return item\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "train_dataset = XFeiDataset(train_encoding, list(train_label))\n",
    "val_dataset = XFeiDataset(val_encoding, list(test_label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:44.110121Z",
     "start_time": "2021-03-11T09:35:44.104871Z"
    },
    "execution": {
     "iopub.execute_input": "2021-07-28T13:33:47.420316Z",
     "iopub.status.busy": "2021-07-28T13:33:47.420156Z",
     "iopub.status.idle": "2021-07-28T13:33:47.825364Z",
     "shell.execute_reply": "2021-07-28T13:33:47.824125Z",
     "shell.execute_reply.started": "2021-07-28T13:33:47.420301Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 精度计算\n",
    "def flat_accuracy(preds, labels):\n",
    "    pred_flat = np.argmax(preds, axis=1).flatten()\n",
    "    labels_flat = labels.flatten()\n",
    "    return np.sum(pred_flat == labels_flat) / len(labels_flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:58:49.027161Z",
     "start_time": "2021-03-11T09:58:45.317009Z"
    },
    "execution": {
     "iopub.execute_input": "2021-07-28T13:33:47.827652Z",
     "iopub.status.busy": "2021-07-28T13:33:47.827148Z",
     "iopub.status.idle": "2021-07-28T13:33:55.429720Z",
     "shell.execute_reply": "2021-07-28T13:33:55.428825Z",
     "shell.execute_reply.started": "2021-07-28T13:33:47.827606Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForNextSentencePrediction: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "from transformers import BertForNextSentencePrediction, AdamW, get_linear_schedule_with_warmup\n",
    "model = BertForNextSentencePrediction.from_pretrained('bert-base-chinese')\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "# 单个读取到批量读取\n",
    "train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)\n",
    "val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=True)\n",
    "\n",
    "# 优化方法\n",
    "optim = AdamW(model.parameters(), lr=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:44:22.077501Z",
     "start_time": "2021-03-11T09:39:16.473609Z"
    },
    "execution": {
     "iopub.execute_input": "2021-07-28T13:30:11.012880Z",
     "iopub.status.busy": "2021-07-28T13:30:11.012294Z",
     "iopub.status.idle": "2021-07-28T13:33:07.672660Z",
     "shell.execute_reply": "2021-07-28T13:33:07.671378Z",
     "shell.execute_reply.started": "2021-07-28T13:30:11.012830Z"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------Epoch: 0 ----------------\n",
      "epoth: 0, iter_num: 100, loss: 0.0017, 8.89%\n",
      "epoth: 0, iter_num: 200, loss: 0.0006, 17.78%\n",
      "epoth: 0, iter_num: 300, loss: 0.0004, 26.67%\n",
      "epoth: 0, iter_num: 400, loss: 0.0013, 35.56%\n",
      "epoth: 0, iter_num: 500, loss: 0.0002, 44.44%\n",
      "epoth: 0, iter_num: 600, loss: 1.2181, 53.33%\n",
      "epoth: 0, iter_num: 700, loss: 0.0142, 62.22%\n",
      "epoth: 0, iter_num: 800, loss: 1.5131, 71.11%\n",
      "epoth: 0, iter_num: 900, loss: 0.0013, 80.00%\n",
      "epoth: 0, iter_num: 1000, loss: 3.3881, 88.89%\n",
      "epoth: 0, iter_num: 1100, loss: 0.0013, 97.78%\n",
      "Epoch: 0, Average training loss: 0.3522\n",
      "Accuracy: 0.8960\n",
      "Average testing loss: 0.5002\n",
      "-------------------------------\n",
      "------------Epoch: 1 ----------------\n",
      "epoth: 1, iter_num: 100, loss: 0.0004, 8.89%\n",
      "epoth: 1, iter_num: 200, loss: 0.0003, 17.78%\n",
      "epoth: 1, iter_num: 300, loss: 0.0019, 26.67%\n",
      "epoth: 1, iter_num: 400, loss: 0.0006, 35.56%\n",
      "epoth: 1, iter_num: 500, loss: 0.0002, 44.44%\n",
      "epoth: 1, iter_num: 600, loss: 0.0013, 53.33%\n",
      "epoth: 1, iter_num: 700, loss: 1.3468, 62.22%\n",
      "epoth: 1, iter_num: 800, loss: 2.4434, 71.11%\n",
      "epoth: 1, iter_num: 900, loss: 0.0000, 80.00%\n",
      "epoth: 1, iter_num: 1000, loss: 0.0057, 88.89%\n",
      "epoth: 1, iter_num: 1100, loss: 0.0020, 97.78%\n",
      "Epoch: 1, Average training loss: 0.1870\n",
      "Accuracy: 0.8720\n",
      "Average testing loss: 0.8334\n",
      "-------------------------------\n",
      "------------Epoch: 2 ----------------\n",
      "epoth: 2, iter_num: 100, loss: 0.0001, 8.89%\n",
      "epoth: 2, iter_num: 200, loss: 0.0000, 17.78%\n",
      "epoth: 2, iter_num: 300, loss: 0.0000, 26.67%\n",
      "epoth: 2, iter_num: 400, loss: 0.0015, 35.56%\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-20-fcc2bfe0fe8c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     56\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"------------Epoch: %d ----------------\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m     \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     59\u001b[0m     \u001b[0mvalidation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     60\u001b[0m     \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf'model_{epoch}.pt'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-20-fcc2bfe0fe8c>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m     11\u001b[0m         \u001b[0mattention_mask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'attention_mask'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m         \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'labels'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m         \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     14\u001b[0m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m         \u001b[0mtotal_train_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[0m\n\u001b[1;32m   1429\u001b[0m             \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1430\u001b[0m             \u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1431\u001b[0;31m             \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1432\u001b[0m         )\n\u001b[1;32m   1433\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    979\u001b[0m             \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    980\u001b[0m             \u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 981\u001b[0;31m             \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    982\u001b[0m         )\n\u001b[1;32m    983\u001b[0m         \u001b[0msequence_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mencoder_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    573\u001b[0m                     \u001b[0mencoder_attention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    574\u001b[0m                     \u001b[0mpast_key_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 575\u001b[0;31m                     \u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    576\u001b[0m                 )\n\u001b[1;32m    577\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    459\u001b[0m             \u001b[0mhead_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    460\u001b[0m             \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 461\u001b[0;31m             \u001b[0mpast_key_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself_attn_past_key_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    462\u001b[0m         )\n\u001b[1;32m    463\u001b[0m         \u001b[0mattention_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself_attention_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    392\u001b[0m             \u001b[0mencoder_attention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    393\u001b[0m             \u001b[0mpast_key_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 394\u001b[0;31m             \u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    395\u001b[0m         )\n\u001b[1;32m    396\u001b[0m         \u001b[0mattention_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    323\u001b[0m             \u001b[0mattention_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mattention_probs\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mhead_mask\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    324\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 325\u001b[0;31m         \u001b[0mcontext_layer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattention_probs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue_layer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    326\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    327\u001b[0m         \u001b[0mcontext_layer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcontext_layer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpermute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontiguous\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# 训练函数\n",
    "def train():\n",
    "    model.train()\n",
    "    total_train_loss = 0\n",
    "    iter_num = 0\n",
    "    total_iter = len(train_loader)\n",
    "    for batch in train_loader:\n",
    "        # 正向传播\n",
    "        optim.zero_grad()\n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device)\n",
    "        labels = batch['labels'].to(device)\n",
    "        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        loss = outputs[0]\n",
    "        total_train_loss += loss.item()\n",
    "        \n",
    "        # 反向梯度信息\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "        \n",
    "        # 参数更新\n",
    "        optim.step()\n",
    "\n",
    "        iter_num += 1\n",
    "        if(iter_num % 100==0):\n",
    "            print(\"epoth: %d, iter_num: %d, loss: %.4f, %.2f%%\" % (epoch, iter_num, loss.item(), iter_num/total_iter*100))\n",
    "        \n",
    "    print(\"Epoch: %d, Average training loss: %.4f\"%(epoch, total_train_loss/len(train_loader)))\n",
    "    \n",
    "def validation():\n",
    "    model.eval()\n",
    "    total_eval_accuracy = 0\n",
    "    total_eval_loss = 0\n",
    "    for batch in val_dataloader:\n",
    "        with torch.no_grad():\n",
    "            # 正常传播\n",
    "            input_ids = batch['input_ids'].to(device)\n",
    "            attention_mask = batch['attention_mask'].to(device)\n",
    "            labels = batch['labels'].to(device)\n",
    "            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        \n",
    "        loss = outputs[0]\n",
    "        logits = outputs[1]\n",
    "\n",
    "        total_eval_loss += loss.item()\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "        label_ids = labels.to('cpu').numpy()\n",
    "        total_eval_accuracy += flat_accuracy(logits, label_ids)\n",
    "        \n",
    "    avg_val_accuracy = total_eval_accuracy / len(val_dataloader)\n",
    "    print(\"Accuracy: %.4f\" % (avg_val_accuracy))\n",
    "    print(\"Average testing loss: %.4f\"%(total_eval_loss/len(val_dataloader)))\n",
    "    print(\"-------------------------------\")\n",
    "    \n",
    "\n",
    "for epoch in range(5):\n",
    "    print(\"------------Epoch: %d ----------------\" % epoch)\n",
    "    train()\n",
    "    validation()\n",
    "    torch.save(model.state_dict(), f'model_{epoch}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-28T13:33:55.430784Z",
     "iopub.status.busy": "2021-07-28T13:33:55.430582Z",
     "iopub.status.idle": "2021-07-28T13:33:55.622292Z",
     "shell.execute_reply": "2021-07-28T13:33:55.621839Z",
     "shell.execute_reply.started": "2021-07-28T13:33:55.430768Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# torch.save(model.state_dict(), f'model_{epoch}.pt')\n",
    "model.load_state_dict(torch.load(f'model_0.pt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-28T13:33:58.553826Z",
     "iopub.status.busy": "2021-07-28T13:33:58.553224Z",
     "iopub.status.idle": "2021-07-28T13:33:58.579133Z",
     "shell.execute_reply": "2021-07-28T13:33:58.578516Z",
     "shell.execute_reply.started": "2021-07-28T13:33:58.553778Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question1</th>\n",
       "      <th>question2</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>玩梦幻西游能赚钱吗</td>\n",
       "      <td>梦幻西游2不花钱能玩吗</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>夏天去什么地方旅游好</td>\n",
       "      <td>夏季去什么地方旅游最好（国内的地方）</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>为什么梦幻西游网站打不开</td>\n",
       "      <td>为什么下载梦幻西游游戏补丁网页打不开</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>这对双胞胎像不</td>\n",
       "      <td>像这样的可爱卡通图片要男的谢了</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>免费网络游戏都有哪些?</td>\n",
       "      <td>有哪些免费的网络游戏</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      question1           question2  label\n",
       "0     玩梦幻西游能赚钱吗         梦幻西游2不花钱能玩吗    0.0\n",
       "1    夏天去什么地方旅游好  夏季去什么地方旅游最好（国内的地方）    0.0\n",
       "2  为什么梦幻西游网站打不开  为什么下载梦幻西游游戏补丁网页打不开    0.0\n",
       "3       这对双胞胎像不     像这样的可爱卡通图片要男的谢了    0.0\n",
       "4   免费网络游戏都有哪些?          有哪些免费的网络游戏    0.0"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_df = pd.read_csv('test.csv', sep='\\t', names=['question1', 'question2', 'label'])\n",
    "test_df['label'] = test_df['label'].fillna(0)\n",
    "test_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-28T13:34:00.852096Z",
     "iopub.status.busy": "2021-07-28T13:34:00.851516Z",
     "iopub.status.idle": "2021-07-28T13:34:01.625504Z",
     "shell.execute_reply": "2021-07-28T13:34:01.624460Z",
     "shell.execute_reply.started": "2021-07-28T13:34:00.852048Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_encoding = tokenizer(list(test_df['question1']), list(test_df['question2']), \n",
    "                          truncation=True, padding=True, max_length=100)\n",
    "\n",
    "test_dataset = XFeiDataset(test_encoding, list(test_df['label']))\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-28T13:34:08.156076Z",
     "iopub.status.busy": "2021-07-28T13:34:08.155495Z",
     "iopub.status.idle": "2021-07-28T13:34:08.164899Z",
     "shell.execute_reply": "2021-07-28T13:34:08.163887Z",
     "shell.execute_reply.started": "2021-07-28T13:34:08.156028Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def predict():\n",
    "    model.eval()\n",
    "    test_predict = []\n",
    "    for batch in test_dataloader:\n",
    "        with torch.no_grad():\n",
    "            # 正常传播\n",
    "            input_ids = batch['input_ids'].to(device)\n",
    "            attention_mask = batch['attention_mask'].to(device)\n",
    "            labels = batch['labels'].to(device)\n",
    "            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        \n",
    "        loss = outputs[0]\n",
    "        logits = outputs[1]\n",
    "\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "        label_ids = labels.to('cpu').numpy()\n",
    "        test_predict += list(np.argmax(logits, axis=1).flatten())\n",
    "        \n",
    "    return test_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-28T13:34:09.546620Z",
     "iopub.status.busy": "2021-07-28T13:34:09.546059Z",
     "iopub.status.idle": "2021-07-28T13:34:22.717710Z",
     "shell.execute_reply": "2021-07-28T13:34:22.717278Z",
     "shell.execute_reply.started": "2021-07-28T13:34:09.546569Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_label = predict()\n",
    "pd.DataFrame({'label':test_label}).to_csv('submit.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-28T13:15:30.634194Z",
     "iopub.status.busy": "2021-07-28T13:15:30.633952Z",
     "iopub.status.idle": "2021-07-28T13:15:30.641406Z",
     "shell.execute_reply": "2021-07-28T13:15:30.640970Z",
     "shell.execute_reply.started": "2021-07-28T13:15:30.634178Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-29T10:18:18.692632Z",
     "iopub.status.busy": "2021-07-29T10:18:18.692003Z",
     "iopub.status.idle": "2021-07-29T10:18:23.335152Z",
     "shell.execute_reply": "2021-07-29T10:18:23.334666Z",
     "shell.execute_reply.started": "2021-07-29T10:18:18.692585Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from keras.applications.vgg16 import VGG16\n",
    "from keras.preprocessing import image\n",
    "from keras.applications.resnet50 import preprocess_input, decode_predictions\n",
    "import numpy as np\n",
    "\n",
    "model = VGG16(weights=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
